package sql

import (
	"errors"
	"fmt"
	"strconv"
	"strings"

	"gitee.com/haodreams/libs/easy"
	"gitee.com/haodreams/golib/logs"
)

//只支持基本的sql语句

// Command kinds for the supported SQL statements, numbered from 1.
const (
	// CmdInsert identifies an INSERT statement.
	CmdInsert = iota + 1
	// CmdSelect identifies a SELECT statement.
	CmdSelect
	// CmdUpdate identifies an UPDATE statement.
	CmdUpdate
	// CmdDelete identifies a DELETE statement.
	CmdDelete
)

// SelectCommand is the parsed form of a SELECT statement, as produced by ParserSQL.
type SelectCommand struct {
	// SQL is the statement text with any trailing ";" removed.
	SQL string
	// Fields lists the (upper-cased) tokens between SELECT and FROM.
	Fields []string
	// Table is the (upper-cased) token following FROM; Where is the condition
	// text after WHERE, in its original case.
	Table, Where string
	// Limit is a {begin, count} pair filled in when a LIMIT clause is present
	// (e.g. "LIMIT 100"); begin is never below 1.
	Limit []int
}

// UpdateCommand describes an UPDATE statement.
// NOTE(review): ParserSQL does not build this type yet.
type UpdateCommand struct {
	// SQL is the raw statement text.
	SQL string
	// Fields lists the updated columns.
	Fields []string
	// Table is the target table and Where the filter condition.
	Table, Where string
}

// DeleteCommand describes a DELETE statement.
// NOTE(review): ParserSQL does not build this type yet.
type DeleteCommand struct {
	// SQL is the raw statement text, Table the target table and Where the
	// filter condition.
	SQL, Table, Where string
}

// ParserSQL parses a basic SQL statement; only simple statements are supported.
//
// The first keyword of s selects the result:
//   - SELECT: the statement is parsed into a *SelectCommand.
//   - INSERT, UPDATE, DELETE: recognised but not parsed yet; (nil, nil) is returned.
//   - anything else, or an empty statement: an error is returned.
//
// Keywords are matched case-insensitively as whole whitespace-separated
// tokens. WHERE and LIMIT are only searched for after the table name.
// Fields and Table come from the upper-cased tokens; Where is cut from the
// original text so the condition keeps its case. A trailing ";" and
// surrounding whitespace are ignored. The LIMIT value must be "n" or
// "begin,n" (spaces around the comma are allowed).
//
// A panic while parsing is logged and converted into the returned error.
func ParserSQL(s string) (cmd interface{}, err error) {
	defer func() {
		if e := recover(); e != nil {
			logs.Info(e)
			err = errors.New(fmt.Sprint(e))
		}
	}()
	s = strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(s), ";"))

	// Tokenize s and record each token's byte offset in s. Offsets must come
	// from s itself: strings.ToUpper can change the byte length of some
	// non-ASCII text, so positions found in an upper-cased copy may not be
	// valid indexes into s.
	words := strings.Fields(s)
	if len(words) == 0 {
		return nil, errors.New("缺少SQL语句")
	}
	starts := make([]int, len(words))
	fields := make([]string, len(words))
	off := 0
	for i, w := range words {
		// w is the next non-space run after off, so its first occurrence
		// at or after off is the token itself.
		starts[i] = off + strings.Index(s[off:], w)
		off = starts[i] + len(w)
		fields[i] = strings.ToUpper(w)
	}

	switch fields[0] {
	case "SELECT":
		idx := easy.IndexArray(fields, "FROM")
		if idx < 0 {
			return nil, errors.New("缺少关键字：'FROM'")
		}
		if idx < 2 {
			return nil, errors.New("缺少需要查询的字段")
		}
		if idx+1 >= len(fields) {
			return nil, errors.New("缺少表名")
		}

		sel := new(SelectCommand)
		sel.SQL = s
		sel.Table = fields[idx+1]
		sel.Fields = fields[1:idx]

		// Optional clauses can only start after the table name.
		tail := idx + 2
		iLimit := len(fields) // end of the WHERE clause when there is no LIMIT
		if i := easy.IndexArray(fields[tail:], "LIMIT"); i >= 0 {
			iLimit = tail + i
			if iLimit+1 >= len(fields) {
				return nil, errors.New("缺少LIMIT的值")
			}
			// Re-join the value tokens so "LIMIT 10, 20" parses like "LIMIT 10,20";
			// unexpected trailing tokens then surface as a conversion error.
			parts := strings.Split(strings.Join(fields[iLimit+1:], ""), ",")
			if len(parts) > 2 {
				return nil, errors.New("LIMIT error, too many values")
			}
			nums := make([]int, len(parts))
			for k, p := range parts {
				n, convErr := strconv.Atoi(p)
				if convErr != nil {
					return nil, errors.New("LIMIT error, " + convErr.Error())
				}
				nums[k] = n
			}
			if len(nums) == 2 {
				begin, limit := nums[0], nums[1]
				if begin < 1 {
					begin = 1
				}
				if limit < 1 {
					limit = 1
				}
				sel.Limit = []int{begin, limit}
			} else {
				sel.Limit = []int{1, nums[0]}
			}
		}

		if i := easy.IndexArray(fields[tail:iLimit], "WHERE"); i >= 0 {
			iWhere := tail + i
			if iWhere+1 >= iLimit {
				return nil, errors.New("缺少查询条件")
			}
			// From the first condition token through the end of the last token
			// before LIMIT (or the end of the statement), in the original case.
			last := iLimit - 1
			sel.Where = s[starts[iWhere+1] : starts[last]+len(words[last])]
		}
		return sel, nil
	case "INSERT", "UPDATE", "DELETE":
		// Recognised but not parsed yet: no command and no error.
		return nil, nil
	default:
		return nil, errors.New("无效的关键字：" + fields[0])
	}
}
